{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple Classification Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:13:57.351628Z",
     "iopub.status.busy": "2025-09-28T08:13:57.351391Z",
     "iopub.status.idle": "2025-09-28T08:14:00.065271Z",
     "shell.execute_reply": "2025-09-28T08:14:00.064578Z"
    }
   },
   "outputs": [],
   "source": [
    "from river import metrics, datasets, compose, preprocessing\n",
    "from deep_river.classification import Classifier\n",
    "from torch import nn\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:14:00.068155Z",
     "iopub.status.busy": "2025-09-28T08:14:00.067864Z",
     "iopub.status.idle": "2025-09-28T08:14:00.715713Z",
     "shell.execute_reply": "2025-09-28T08:14:00.715221Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><div class=\"river-component river-pipeline\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">StandardScaler</pre></summary><code class=\"river-estimator-params\">StandardScaler (\n",
       "  with_std=True\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">Classifier</pre></summary><code class=\"river-estimator-params\">Classifier (\n",
       "  module=MyModule(\n",
       "  (dense0): Linear(in_features=10, out_features=5, bias=True)\n",
       "  (nonlin): ReLU()\n",
       "  (dense1): Linear(in_features=5, out_features=2, bias=True)\n",
       "  (softmax): Softmax(dim=-1)\n",
       ")\n",
       "  loss_fn=\"binary_cross_entropy\"\n",
       "  optimizer_fn=\"adam\"\n",
       "  lr=0.001\n",
       "  output_is_logit=True\n",
       "  is_class_incremental=False\n",
       "  is_feature_incremental=False\n",
       "  device=\"cpu\"\n",
       "  seed=42\n",
       ")\n",
       "</code></details></div><style scoped>\n",
       ".river-estimator {\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "    max-width: max-content;\n",
       "}\n",
       "\n",
       ".river-pipeline {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    background: linear-gradient(#000, #000) no-repeat center / 1.5px 100%;\n",
       "}\n",
       "\n",
       ".river-union {\n",
       "    display: flex;\n",
       "    flex-direction: row;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-estimator {\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       "/* Vertical spacing between steps */\n",
       "\n",
       ".river-component + .river-component {\n",
       "    margin-top: 2em;\n",
       "}\n",
       "\n",
       ".river-union > .river-estimator {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .river-component {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .pipeline {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       "/* Spacing within a union of estimators */\n",
       "\n",
       ".river-union > .river-component + .river-component {\n",
       "    margin-left: 1em;\n",
       "}\n",
       "\n",
       "/* Typography */\n",
       "\n",
       ".river-estimator-params {\n",
       "    display: block;\n",
       "    white-space: pre-wrap;\n",
       "    font-size: 110%;\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator > .river-estimator-params,\n",
       ".river-wrapper > .river-details > river-estimator-params {\n",
       "    background-color: white !important;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-details {\n",
       "    margin-bottom: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator-name {\n",
       "    display: inline;\n",
       "    margin: 0;\n",
       "    font-size: 110%;\n",
       "}\n",
       "\n",
       "/* Toggle */\n",
       "\n",
       ".river-summary {\n",
       "    display: flex;\n",
       "    align-items:center;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".river-summary > div {\n",
       "    width: 100%;\n",
       "}\n",
       "</style></div>"
      ],
      "text/plain": [
       "Pipeline (\n",
       "  StandardScaler (\n",
       "    with_std=True\n",
       "  ),\n",
       "  Classifier (\n",
       "    module=MyModule(\n",
       "    (dense0): Linear(in_features=10, out_features=5, bias=True)\n",
       "    (nonlin): ReLU()\n",
       "    (dense1): Linear(in_features=5, out_features=2, bias=True)\n",
       "    (softmax): Softmax(dim=-1)\n",
       "  )\n",
       "    loss_fn=\"binary_cross_entropy\"\n",
       "    optimizer_fn=\"adam\"\n",
       "    lr=0.001\n",
       "    output_is_logit=True\n",
       "    is_class_incremental=False\n",
       "    is_feature_incremental=False\n",
       "    device=\"cpu\"\n",
       "    seed=42\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Data stream the model learns from: river's built-in Phishing dataset.\n",
    "dataset = datasets.Phishing()\n",
    "# Running accuracy; it is updated one prediction at a time in the evaluation loop.\n",
    "metric = metrics.Accuracy()\n",
    "\n",
    "\n",
    "class MyModule(nn.Module):\n",
    "    def __init__(self, n_features):\n",
    "        super(MyModule, self).__init__()\n",
    "        self.dense0 = nn.Linear(n_features, 5)\n",
    "        self.nonlin = nn.ReLU()\n",
    "        self.dense1 = nn.Linear(5, 2)\n",
    "        self.softmax = nn.Softmax(dim=-1)\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        X = self.nonlin(self.dense0(X))\n",
    "        X = self.nonlin(self.dense1(X))\n",
    "        X = self.softmax(X)\n",
    "        return X\n",
    "\n",
    "\n",
    "# Neural classifier wrapping MyModule, built first so the pipeline reads as two steps.\n",
    "classifier = Classifier(\n",
    "    module=MyModule(10),\n",
    "    loss_fn=\"binary_cross_entropy\",\n",
    "    optimizer_fn=\"adam\",\n",
    ")\n",
    "\n",
    "# The scaler is passed as a class; the pipeline shows it instantiated with defaults.\n",
    "model_pipeline = compose.Pipeline(preprocessing.StandardScaler, classifier)\n",
    "model_pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:14:00.718830Z",
     "iopub.status.busy": "2025-09-28T08:14:00.718400Z",
     "iopub.status.idle": "2025-09-28T08:14:01.884166Z",
     "shell.execute_reply": "2025-09-28T08:14:01.883275Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.564\n"
     ]
    }
   ],
   "source": [
    "# Test-then-train: each sample is predicted and scored before the model learns from it.\n",
    "for features, label in dataset.take(5000):\n",
    "    prediction = model_pipeline.predict_one(features)\n",
    "    metric.update(label, prediction)\n",
    "    model_pipeline.learn_one(features, label)\n",
    "\n",
    "print(f\"Accuracy: {metric.get()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deep-river",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
